【从0开始】使用Flax NNX API 构建简单神经网络并训练

  • 1 min read

与 Linen API 不同,NNX 使用起来对初学者更加简单,跟 PyTorch 的体验更加接近。

任务

使用 MLP 拟合简单函数: $$ y=2x^2+1 $$

代码

from typing import Generator

import jax.numpy as jnp
import jax.random as jrm
import optax as ox
from flax import nnx
from jax import Array


class Network(nnx.Module):
    """A three-layer MLP: two hidden layers with ReLU, linear output layer."""

    def __init__(self, in_dim: int, out_dim: int, rng: nnx.Rngs, hidden_dim: int):
        super().__init__()
        # Layer widths: in_dim -> hidden_dim -> hidden_dim -> out_dim.
        self.linear1 = nnx.Linear(in_dim, hidden_dim, rngs=rng)
        self.linear2 = nnx.Linear(hidden_dim, hidden_dim, rngs=rng)
        self.linear3 = nnx.Linear(hidden_dim, out_dim, rngs=rng)

    def __call__(self, x) -> Array:
        """Forward pass; final layer has no activation (regression output)."""
        hidden = nnx.relu(self.linear1(x))
        hidden = nnx.relu(self.linear2(hidden))
        return self.linear3(hidden)


def make_dataset(
    X: Array, Y: Array, batch: int, key
) -> Generator[tuple[Array, Array, Array], None, None]:
    """Infinite sampler yielding (x, y, key) minibatches drawn with replacement."""
    # Pair each x with its y along axis 1 and add a trailing feature axis,
    # giving shape (N, 2, 1) so x/y rows stay aligned when sampled together.
    pairs = jnp.stack((X, Y), axis=1)[..., None]
    while True:
        # Split the key each round so successive batches differ.
        key, sample_key = jrm.split(key)
        rows = jrm.choice(sample_key, pairs, shape=(batch,))
        yield rows[:, 0], rows[:, 1], key


def loss_fn(model: Network, batch):
    """Mean optax l2 loss (0.5 * squared error) of the model on one (x, y) batch."""
    inputs, targets = batch
    outputs = model(inputs)
    return ox.l2_loss(outputs, targets).mean()


@nnx.jit
def train_step(model: Network, optimizer: nnx.Optimizer, batch):
    """One JIT-compiled step: compute loss and grads, apply the update in place."""
    grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = grad_fn(model, batch)
    optimizer.update(grads)
    return loss


# --- hyper-parameters ---
seed = 0
batch = 32

# --- synthetic dataset for the target function y = 2x^2 + 1 ---
X = jnp.arange(0, 10, 0.005)
Y = 2 * X**2 + 1.0

# --- model & optimizer (AdamW, lr=0.001) ---
model = Network(1, 1, hidden_dim=20, rng=nnx.Rngs(seed))
optimizer = nnx.Optimizer(model, ox.adamw(0.001))

# --- training loop: 6001 steps over the infinite sampler ---
key = jrm.key(seed)
for i, (x, y, _) in enumerate(make_dataset(X, Y, batch, key)):
    loss = train_step(model, optimizer, (x, y))
    print(i, loss)
    if i >= 6000:
        break

依赖如下

$ uv pip list
Package              Version
-------------------- --------
absl-py              2.4.0
aiofiles             25.1.0
box2d                2.3.10
cloudpickle          3.1.2
colorama             0.4.6
docstring-parser     0.17.0
etils                1.13.0
farama-notifications 0.0.4
flax                 0.10.7
fsspec               2026.2.0
gymnasium            1.2.3
humanize             4.15.0
importlib-resources  6.5.2
jax                  0.7.2
jaxlib               0.7.2
markdown-it-py       4.0.0
mdurl                0.1.2
ml-dtypes            0.5.1
msgpack              1.1.2
nest-asyncio         1.6.0
numpy                2.2.2
opt-einsum           3.4.0
optax                0.2.7
orbax-checkpoint     0.11.32
packaging            26.0
protobuf             6.33.5
psutil               7.2.2
pygame               2.6.1
pygments             2.19.2
pyyaml               6.0.3
rich                 14.3.3
scipy                1.15.1
simplejson           3.20.2
swig                 4.4.1
tensorboardx         2.6.4
tensorstore          0.1.81
tqdm                 4.67.3
treescope            0.1.10
typeguard            4.5.1
typing-extensions    4.15.0
tyro                 1.0.7
zipp                 3.23.0